#include "statsxx/machine_learning/NeuralNet.hpp"

// STL
#include <iostream>
#include <stdexcept> // std::runtime_error

// stats++
#include "statsxx/machine_learning/activation_functions.hpp" // activation_function::Logistic, activation_function::softplus


/// @brief Recursively computes the partial derivative dz/dx of neuron
///        `neuron_idx`'s output with respect to input neuron `inp_idx`,
///        via the chain rule: dz/dx = (dz/dI) * sum_i w_i * (dy_i/dx).
/// @param neuron_idx index into m_neurons of the neuron being differentiated
/// @param inp_idx    index into m_neurons of the input neuron
/// @return dz/dx
/// @throws std::runtime_error for SOFTMAX (no element-wise derivative here)
///         or an unrecognized activation function.
/// @note Recursion terminates at neuron_idx == inp_idx; cost grows with the
///       number of paths between the two neurons (no memoization).
inline double NEURAL_NET::deriv(const int neuron_idx, const int inp_idx)
{
    // Base case: derivative of a value with respect to itself (dx/dx) is 1.
    if( neuron_idx == inp_idx )
    {
        return 1.0;
    }

    // calculate dz/dI: the activation function's derivative, evaluated by
    // mapping the stored output back to the pre-activation input via inv().
    // Initialized so no path through the switch can leave it indeterminate
    // (the original left it uninitialized on the default branch — UB).
    double dzdI = 1.0;

    activation_function::Logistic    logistic;
    activation_function::softplus    softplus;
    activation_function::tanh_skewed tanh_skewed;

    switch(this->m_neurons[neuron_idx].act_func)
    {
        case NEURON::ACTIVATION_FUNC::IDENTITY:
            dzdI = 1.0;
            break;
        case NEURON::ACTIVATION_FUNC::LOGISTIC:
            dzdI = logistic.df(logistic.inv(this->m_neurons[neuron_idx].m_output.front()));
            break;
        case NEURON::ACTIVATION_FUNC::TANH:
            dzdI = tanh_skewed.df( tanh_skewed.inv( this->m_neurons[neuron_idx].m_output.front() ) );
            break;
        case NEURON::ACTIVATION_FUNC::SOFTPLUS:
            dzdI = softplus.df(softplus.inv(this->m_neurons[neuron_idx].m_output.front()));
            break;
        case NEURON::ACTIVATION_FUNC::SOFTMAX:
            // Softmax's derivative is a full Jacobian, not element-wise, so it
            // cannot be expressed as a scalar dz/dI here. Throw instead of the
            // original exit(0), which reported success on an error path.
            throw std::runtime_error("NEURAL_NET::deriv(): softmax derivative is not supported");
        default:
            // Unknown activation: fail loudly rather than silently using an
            // indeterminate dzdI.
            throw std::runtime_error("NEURAL_NET::deriv(): unknown activation function");
    }

    // calculate dI/dx = sum_i w_i*dy_i/dx over all incoming links.
    double dIdx = 0.0;

    for( auto l : this->m_links_in[neuron_idx] )
    {
        dIdx += m_links[l].weight*deriv(this->m_links[l].source, inp_idx);
    }

    // chain rule: dz/dx = (dz/dI)*(dI/dx)
    return (dzdI*dIdx);
}


